{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWQ on Vicuna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement AWQ real-INT4 inference kernels, which are wrapped as Pytorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you need to install the following packages:\n",
    "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n",
    "- [Pytorch](https://pytorch.org/)\n",
    "- [Accelerate](https://github.com/huggingface/accelerate)\n",
    "- [Transformers](https://github.com/huggingface/transformers)"
   ]
  },
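  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A possible installation, assuming a CUDA-enabled environment (exact steps and versions may differ; please refer to each project's README):\n",
    "\n",
    "```bash\n",
    "pip install torch accelerate transformers\n",
    "git clone https://github.com/mit-han-lab/llm-awq\n",
    "cd llm-awq\n",
    "pip install -e .\n",
    "# build the real-INT4 CUDA kernels used by TinyChat (path may vary across versions)\n",
    "cd awq/kernels && python setup.py install\n",
    "```"
   ]
  },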
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
    "from awq.quantize.quantizer import real_quantize_model_weight\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
    "from tinychat.demo import gen_params, stream_output\n",
    "from tinychat.stream_generators import StreamGenerator\n",
    "from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp\n",
    "from tinychat.utils.prompt_templates import get_prompter\n",
    "import os\n",
    "# This demo only support single GPU for now\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat) and run the following command to generate a quantized model checkpoint first.\n",
    "\n",
    "```bash\n",
    "mkdir quant_cache\n",
    "python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
    "    --w_bit 4 --q_group_size 128 \\\n",
    "    --load_awq awq_cache/vicuna-7b-w4-g128.pt \\\n",
    "    --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt\n",
    "```"
   ]
  },
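  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `--load_awq` argument above expects precomputed AWQ search results. If `awq_cache/vicuna-7b-w4-g128.pt` is not available (precomputed results can typically be downloaded from the llm-awq model zoo), it can be regenerated with the search stage of `awq.entry`. The command below is a sketch; the exact flags may differ slightly across llm-awq versions:\n",
    "\n",
    "```bash\n",
    "mkdir awq_cache\n",
    "python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
    "    --w_bit 4 --q_group_size 128 \\\n",
    "    --run_awq --dump_awq awq_cache/vicuna-7b-w4-g128.pt\n",
    "```"
   ]
  },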
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_path = \"\" # the path of vicuna-7b model\n",
    "# load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\"\n",
    "model_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b\"\n",
    "load_quant_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b79a82b73ab4d9191ba54f5d0f8cb86",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "real weight quantization...(init only): 100%|███████████████████| 32/32 [00:11<00:00,  2.69it/s]\n",
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
     ]
    }
   ],
   "source": [
    "config = AutoConfig.from_pretrained(model_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "with init_empty_weights():\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,\n",
    "                                                    torch_dtype=torch.float16)\n",
    "q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
    "real_quantize_model_weight(\n",
    "    model, w_bit=4, q_config=q_config, init_only=True)\n",
    "\n",
    "model = load_checkpoint_and_dispatch(\n",
    "    model, load_quant_path,\n",
    "    device_map=\"auto\",\n",
    "    no_split_module_classes=[\"LlamaDecoderLayer\"]\n",
    ")"
   ]
  },
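  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we apply TinyChat's fused modules to further speed up inference: `make_quant_attn` fuses the Q/K/V projections into a single `QuantLlamaAttention` with one `WQLinear` qkv projection, `make_quant_norm` replaces the RMSNorm layers with `FTLlamaRMSNorm`, and `make_fused_mlp` fuses the MLP into `QuantLlamaMLP`."
   ]
  },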
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): QuantLlamaAttention(\n",
       "          (qkv_proj): WQLinear(in_features=4096, out_features=12288, bias=False, w_bit=4, group_size=128)\n",
       "          (o_proj): WQLinear(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)\n",
       "          (rotary_emb): QuantLlamaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): QuantLlamaMLP(\n",
       "          (down_proj): WQLinear(in_features=11008, out_features=4096, bias=False, w_bit=4, group_size=128)\n",
       "        )\n",
       "        (input_layernorm): FTLlamaRMSNorm()\n",
       "        (post_attention_layernorm): FTLlamaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): FTLlamaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_quant_attn(model, \"cuda:0\")\n",
    "make_quant_norm(model)\n",
    "make_fused_mlp(model)"
   ]
  },
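  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we chat with the quantized Vicuna model. Each user prompt is wrapped with the model's conversation template, the response is streamed back token by token, and the measured inference speed is printed after every turn. Enter an empty prompt to exit."
   ]
  },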
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "USER:  Show me some attractions in Boston.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ASSISTANT: 1. Boston Public Library\n",
      "2. Fenway Park\n",
      "3. Harvard Square\n",
      "4. Boston Common\n",
      "5. Freedom Trail\n",
      "6. Museum of Fine Arts\n",
      "7. Isabella Stewart Gardner Museum\n",
      "8. Paul Revere House\n",
      "9. New England Aquarium\n",
      "10. Museum of Science\n",
      "==================================================\n",
      "Speed of Inference\n",
      "--------------------------------------------------\n",
      "Context Stage    : 7.18 ms/token\n",
      "Generation Stage : 9.49 ms/token\n",
      "Average Speed    : 8.53 ms/token\n",
      "==================================================\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "USER:  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EXIT...\n"
     ]
    }
   ],
   "source": [
    "model_prompter = get_prompter(\"llama\", model_path)\n",
    "stream_generator = StreamGenerator\n",
    "count = 0\n",
    "while True:\n",
    "    # Get input from the user\n",
    "    input_prompt = input(\"USER: \")\n",
    "    if input_prompt == \"\":\n",
    "        print(\"EXIT...\")\n",
    "        break\n",
    "    model_prompter.insert_prompt(input_prompt)\n",
    "    output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=\"cuda:0\")\n",
    "    outputs = stream_output(output_stream)    \n",
    "    model_prompter.update_template(outputs)\n",
    "    count += 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (awq)",
   "language": "python",
   "name": "awq"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
